package main

import (
	"fmt"
	"os"
	"syscall"
	"unsafe"

	io_uring "gitee.com/childewang/iouring-go"
	iouring_syscall "gitee.com/childewang/iouring-go/syscall"
)

// Sizing parameters for the io_uring based copy.
const (
	// entries is the queue depth passed to io_uring.New and the upper bound
	// on how many requests copyFile queues at once.
	entries uint = 16

	// blockSize is the largest number of bytes requested by a single queued
	// read (32 KiB).
	blockSize int64 = 32 * 1024
)

// ioData is the per-request bookkeeping record attached to a submission queue
// entry. Its address is stored as the entry's user data and recovered from the
// matching completion.
type ioData struct {
	// read is not assigned or consulted anywhere in the visible code.
	read        int
	// firstOffset is the file offset the request was originally queued with.
	firstOffset uint64
	// offset is the file offset the request operates on.
	offset      uint64
	// firstLen is the byte count the request was originally queued with.
	firstLen    uint64
	// iov is the iovec list handed to the readv/writev preparation call.
	iov         []syscall.Iovec
}

// queueRead prepares (but does not submit) a vectored read of size bytes from
// infd at the given file offset.
//
// A fresh size-byte buffer is allocated as the read destination and described
// by a single iovec. An ioData record holding the offset, length and iovec is
// attached to the submission entry as user data so the completion handler can
// find the buffer again.
//
// It returns false when no submission queue entry is available, and true once
// the request has been prepared.
//
// NOTE(review): the record's address is handed to the ring as a plain integer,
// so nothing on the Go side keeps the record (or the buffer it references)
// reachable while the request is in flight. The garbage collector is therefore
// free to reclaim it; a registry of in-flight records owned by the caller
// would be needed to make this fully safe.
func queueRead(iouring *io_uring.IOUring, size uint64, offset uint64, infd int) bool {
	// Grab the entry first so nothing is allocated when the queue is full.
	sqe := iouring.GetSQE()
	if sqe == nil {
		return false
	}

	// The iovec must point at real storage of the advertised length; the
	// kernel writes size bytes starting at Base.
	buf := make([]byte, size)
	var base *byte
	if len(buf) > 0 {
		base = &buf[0]
	}

	iovs := []syscall.Iovec{
		{Base: base, Len: size},
	}

	data := &ioData{
		firstOffset: offset,
		offset:      offset,
		firstLen:    size,
		iov:         iovs,
	}

	sqe.PrepReadv(infd, iovs, 1, int64(offset))
	sqe.SetData(uint64(uintptr(unsafe.Pointer(data))))

	return true
}

// queueWrite prepares (but does not submit) a vectored write of iovs to outfd
// at the given file offset, attaching a new ioData record that stores the
// offset, size and iovec list as the entry's user data.
//
// The return value is 0 on every path, both when no submission queue entry
// is available and after the write has been prepared, so callers cannot use
// it to tell the two cases apart.
func queueWrite(iouring *io_uring.IOUring, size uint64, iovs []syscall.Iovec, offset uint64, outfd int) int {
	entry := iouring.GetSQE()
	if entry == nil {
		return 0
	}

	req := &ioData{
		firstOffset: offset,
		offset:      offset,
		firstLen:    size,
		iov:         iovs,
	}

	// The literal 1 is presumably the iovec count (confirm against the
	// library's PrepWritev signature).
	entry.PrepWritev(outfd, iovs, 1, int64(offset))
	entry.SetData(uint64(uintptr(unsafe.Pointer(req))))

	return 0
}

// queuePrepped is an unimplemented placeholder: it accepts a ring and a
// request record but currently does nothing with either.
func queuePrepped(iouring *io_uring.IOUring, data *ioData) {

}

// copyFile copies the contents of src into dest through a private io_uring
// instance.
//
// The file is processed in batches. Each batch queues up to `entries` reads of
// at most blockSize bytes, submits them, turns every completed read into a
// write of the same buffer at the same offset, submits those writes and waits
// for them to complete. Repeating this until the whole file has been queued is
// what allows files larger than entries*blockSize to be copied in full.
//
// Failing to create the ring, stat src, submit a batch or reap a completion
// panics.
//
// NOTE(review): the result code carried by each completion is not inspected,
// so a failed or short read is still forwarded as a full-length write, and a
// failed write goes unnoticed. TODO: check the completion result once the
// field name is confirmed against the ring library.
func copyFile(src, dest *os.File) {
	params := iouring_syscall.IOURingParams{
		Flags: uint32(0),
	}

	iouring, err := io_uring.New(entries, &params)
	if err != nil {
		panic(fmt.Sprintf("new IOURing error: %v", err))
	}
	defer iouring.Close()

	stat, err := src.Stat()
	if err != nil {
		panic(err)
	}
	remaining := stat.Size()

	infd := int(src.Fd())
	outfd := int(dest.Fd())

	var offset uint64
	for remaining > 0 {
		// Phase 1: queue as many block-sized reads as the ring accepts.
		reads := 0
		for remaining > 0 && reads < int(entries) {
			thisSize := remaining
			if thisSize > blockSize {
				thisSize = blockSize
			}
			if !queueRead(iouring, uint64(thisSize), offset, infd) {
				break
			}
			remaining -= thisSize
			offset += uint64(thisSize)
			reads++
		}
		if reads == 0 {
			// Without this guard a ring that never yields an entry would
			// make the outer loop spin forever.
			panic("copyFile: no submission queue entry available")
		}

		if _, err := iouring.Submit(); err != nil {
			panic(fmt.Sprintf("submit reads error: %v", err))
		}

		// Phase 2: every finished read becomes a write of the same buffer
		// at the same offset. The reads were just submitted, so the queue
		// is expected to have room for one write per read.
		writes := 0
		for i := 0; i < reads; i++ {
			cqe, err := iouring.GetCQE(true)
			if err != nil {
				panic(fmt.Sprintf("wait read completion error: %v", err))
			}
			data := (*ioData)(unsafe.Pointer(uintptr(cqe.UserData)))
			// Pass the block's own length, not the bytes still left to
			// queue for the rest of the file.
			queueWrite(iouring, data.firstLen, data.iov, data.offset, outfd)
			writes++
		}

		if _, err := iouring.Submit(); err != nil {
			panic(fmt.Sprintf("submit writes error: %v", err))
		}

		// Phase 3: wait for exactly the writes that were queued, so the
		// ring is idle before the next batch (or before it is closed).
		for i := 0; i < writes; i++ {
			if _, err := iouring.GetCQE(true); err != nil {
				panic(fmt.Sprintf("wait write completion error: %v", err))
			}
		}
	}
}

// main expects exactly two arguments, a source path and a destination path.
// It opens the source for reading, creates (or truncates) the destination and
// hands both files to copyFile. A wrong argument count or a failure to open
// either file is reported on standard output and ends the program without
// copying.
func main() {
	args := os.Args
	if len(args) != 3 {
		fmt.Printf("Usage: %s file1 file2\n", args[0])
		return
	}
	srcPath, destPath := args[1], args[2]

	in, openErr := os.Open(srcPath)
	if openErr != nil {
		fmt.Printf("Open src file failed: %v\n", openErr)
		return
	}
	defer in.Close()

	out, createErr := os.Create(destPath)
	if createErr != nil {
		fmt.Printf("create dest file failed: %v\n", createErr)
		return
	}
	defer out.Close()

	copyFile(in, out)
}
